Conversation

@tc-huang tc-huang commented Oct 10, 2025

Note

Work in progress!
This PR is my attempt to address issue #2061. As it's a complex issue and the PR is still under development, all collaboration, help, and suggestions would be greatly appreciated!

What this does

(WIP) Explain what this PR does.

```
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0] Graph break from `Tensor.item()`, consider setting:
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0] or:
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0] to include these operations in the captured graph.
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0]
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0] Graph break: from user code at:
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0]   File "/home/tc-huang/Desktop/lerobot/src/lerobot/policies/act/modeling_act.py", line 147, in forward
W1018 23:02:25.828000 103342 torch/_dynamo/variables/tensor.py:913] [0/0]     loss_dict = {"l1_loss": l1_loss.item()}
```
  • src/lerobot/policies/act/modeling_act.py

    ```diff
    -        loss_dict = {"l1_loss": l1_loss.item()}
    +        loss_dict = {"l1_loss": l1_loss}
             if self.config.use_vae:
                 # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
                 # each dimension independently, we sum over the latent dimension to get the total
                 # KL-divergence per batch element, then take the mean over the batch.
                 # (See App. B of https://huggingface.co/papers/1312.6114 for more details).
                 mean_kld = (
                     (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
                 )
    -            loss_dict["kld_loss"] = mean_kld.item()
    +            loss_dict["kld_loss"] = mean_kld
                 loss = l1_loss + mean_kld * self.config.kl_weight
             else:
                 loss = l1_loss

             return loss, loss_dict
    ```

  • src/lerobot/scripts/lerobot_train.py

    ```diff
         with accelerator.autocast():
             loss, output_dict = policy.forward(batch)
    +        output_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()}
    ```
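
The pattern moved into the training script can be illustrated without torch: keep tensor-like values in the dict inside `forward()` (so the compiler never traces `.item()`), and unbox them to plain Python scalars only at the logging boundary. `ScalarBox`, `forward_outputs`, and `to_scalars` below are hypothetical stand-ins, not the real LeRobot code; a minimal sketch only.

```python
# Hypothetical stand-in for a 0-dim torch.Tensor (real code uses torch).
class ScalarBox:
    def __init__(self, value: float):
        self.value = value

    def item(self) -> float:
        # Mirrors Tensor.item(): boxed value -> plain Python scalar.
        return self.value


def forward_outputs():
    # Inside forward(), keep boxed values so a tracer never sees .item().
    l1_loss = ScalarBox(0.25)
    mean_kld = ScalarBox(0.01)
    return {"l1_loss": l1_loss, "kld_loss": mean_kld}


def to_scalars(output_dict):
    # At the logging boundary (the training script), unbox for metrics.
    return {k: v.item() if isinstance(v, ScalarBox) else v for k, v in output_dict.items()}


if __name__ == "__main__":
    print(to_scalars(forward_outputs()))
```

The real diff applies the same comprehension with `torch.Tensor` in place of `ScalarBox`.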

Current progress

  • ✅ Compilation: PASSED

  • ❌ Correctness: FAILED

  • ❌ Performance: FAILED

  • ✅ Benchmarking: PASSED

    When compiling the policy with `torch.compile(policy, mode="default")`, the `loss_diff` of the training test within `test_correctness` fails to meet the 1e-5 tolerance. Using either `torch.compile(policy, mode="default", options={"fallback_random": True})` or `torch.compile(policy, mode="default", backend="aot_eager")` satisfies the tolerance, but neither configuration provides a noticeable performance improvement.

How it was tested

(WIP) Explain/show how you tested your changes.

  • Baseline and final benchmark reports
- lerobot version: 0.3.4
- Platform: Linux-6.8.0-84-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface Hub version: 0.35.1
- Datasets version: 4.1.1
- Numpy version: 2.2.6
- PyTorch version: 2.7.1+cu126
- Is PyTorch built with CUDA support?: True
- Cuda version: 12.6
- GPU model: NVIDIA GeForce RTX 4090
- Using GPU in script?: Yes
```
python benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py --policy act --device cuda --output benchmarks/policies_compilation/baseline_act_report.md
```

How to checkout & try? (for the reviewer)

(WIP) Provide a simple way for the reviewer to try out your changes.

```
python benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py --policy act --device cuda --output benchmarks/policies_compilation/baseline_act_report.md
```

…pile graph break

Removed `.item()` calls from `loss_dict` in `forward()` to avoid breaking the
`torch.compile` computation graph. The tensor-to-scalar conversion is now
handled in the training script instead.
In the inference benchmark in `benchmark_inference_compile_lerobot.py`,
the `test_correctness` method failed to properly compare compiled
`ACTPolicy` inference results. This was due to `policy_original`
and `policy_compiled` sharing the same `_action_queue` object.

Previously, the call order was:
1. `policy_original.reset()`
2. `policy_compiled.reset()`
3. `policy_original.select_action()`
4. `policy_compiled.select_action()`

Because the `_action_queue` is shared,
`policy_original.select_action()` would run inference
(`predict_action_chunk`) and extend the queue with `n_action_steps`
actions. `policy_compiled.select_action()` would then find a
non-empty queue and simply pop an action, bypassing its own
compiled inference logic.

This commit reorders the calls to:
1. `policy_original.reset()`
2. `policy_original.select_action()`
3. `policy_compiled.reset()`
4. `policy_compiled.select_action()`

This change ensures that `policy_compiled.reset()` clears the shared
queue *after* the original policy's action selection. Consequently,
`policy_compiled.select_action()` finds an empty queue and executes
its own compiled inference, allowing for a correct comparison.

With this fix, the compiled `ACTPolicy` inference check within
`test_correctness` now passes, validating that the compiled
inference output matches the original.
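
The shared-queue behavior described above can be sketched without torch using a plain `collections.deque`. `TinyPolicy` below is a hypothetical stand-in that mimics only the queue logic of `ACTPolicy.select_action()` (run inference only when the queue is empty, otherwise pop); it is not the real implementation.

```python
from collections import deque


class TinyPolicy:
    """Hypothetical stand-in mimicking ACTPolicy's action-queue logic."""

    def __init__(self, shared_queue: deque, tag: str, n_action_steps: int = 3):
        self._action_queue = shared_queue  # shared between "original" and "compiled"
        self.tag = tag
        self.n_action_steps = n_action_steps

    def reset(self):
        self._action_queue.clear()

    def select_action(self) -> str:
        # "Inference" runs only when the queue is empty, as in ACTPolicy.
        if not self._action_queue:
            self._action_queue.extend(f"{self.tag}-act{i}" for i in range(self.n_action_steps))
        return self._action_queue.popleft()


queue = deque()
original = TinyPolicy(queue, tag="orig")
compiled = TinyPolicy(queue, tag="comp")

# Old call order: compiled pops an action the original policy produced,
# so the compiled inference path is never exercised.
original.reset()
compiled.reset()
a1 = original.select_action()
a2 = compiled.select_action()

# Fixed call order: compiled.reset() clears the shared queue after the
# original's action selection, forcing compiled inference to actually run.
original.reset()
b1 = original.select_action()
compiled.reset()
b2 = compiled.select_action()
```

With the old order both `a1` and `a2` carry the `orig` tag; with the fix, `b2` carries the `comp` tag, so the comparison exercises the compiled path.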